import asyncio
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence

from autogen_core import (
    AgentId,
    AgentRuntime,
    AgentType,
    CancellationToken,
    ComponentBase,
    SingleThreadedAgentRuntime,
    TypeSubscription,
)
from pydantic import BaseModel, ValidationError

from autogen_agentchat.base import ChatAgent, TaskResult, Team, TerminationCondition
from autogen_agentchat.messages import (
    BaseAgentEvent,
    BaseChatMessage,
    MessageFactory,
    ModelClientStreamingChunkEvent,
    StopMessage,
    StructuredMessage,
    TextMessage,
)
from autogen_agentchat.state import TeamState
from autogen_agentchat.teams._group_chat._chat_agent_container import ChatAgentContainer
from autogen_agentchat.teams._group_chat._events import (
    GroupChatPause,
    GroupChatReset,
    GroupChatResume,
    GroupChatStart,
    GroupChatTermination,
    SerializableException,
)
from autogen_agentchat.teams._group_chat._sequential_routed_agent import SequentialRoutedAgent


class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
    """
    群组聊天团队的基类。
    要实现一个群组聊天团队，首先需创建 :class:`BaseGroupChatManager` 的子类，然后
    创建使用该群组聊天管理器的 :class:`BaseGroupChat` 子类。
    """
    component_type = "team"

    def __init__(
        self,
        participants: List[ChatAgent],
        group_chat_manager_name: str,
        group_chat_manager_class: type[SequentialRoutedAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
    ):
        """Validate the participants and prepare topics, the output queue and the runtime.

        Args:
            participants: Chat agents taking part; must be non-empty with unique names.
            group_chat_manager_name: Name of the group chat manager; also the prefix of
                its direct topic type.
            group_chat_manager_class: Agent class registered as the group chat manager.
            termination_condition: Optional condition handed to the manager factory.
            max_turns: Optional turn limit handed to the manager factory.
            runtime: Agent runtime to use. When ``None``, an embedded
                ``SingleThreadedAgentRuntime`` is created and owned by the team.
            custom_message_types: Extra message types registered with the message factory.
            emit_team_events: Stored as ``self._emit_team_events``.

        Raises:
            ValueError: If ``participants`` is empty or the participant names repeat.
        """
        if not participants:
            raise ValueError("At least one participant is required.")
        participant_names = [participant.name for participant in participants]
        if len(set(participant_names)) != len(participants):
            raise ValueError("The participant names must be unique.")

        self._participants = participants
        self._base_group_chat_manager_class = group_chat_manager_class
        self._termination_condition = termination_condition
        self._max_turns = max_turns

        # Message factory: explicit custom types first, then every StructuredMessage
        # subclass that a participant declares it produces (if not already known).
        self._message_factory = MessageFactory()
        for custom_type in custom_message_types or []:
            self._message_factory.register(custom_type)
        for participant in participants:
            for produced_type in participant.produced_message_types:
                try:
                    already_registered = self._message_factory.is_registered(produced_type)  # type: ignore[reportUnknownArgumentType]
                    if issubclass(produced_type, StructuredMessage) and not already_registered:
                        self._message_factory.register(produced_type)  # type: ignore[reportUnknownArgumentType]
                except TypeError:
                    # Not a class / not usable with issubclass: skip it.
                    continue

        self._team_id = str(uuid.uuid4())
        self._group_chat_manager_name = group_chat_manager_name
        self._participant_names: List[str] = participant_names
        self._participant_descriptions: List[str] = [participant.description for participant in participants]

        # Every topic type below is suffixed with this team's id.
        # Broadcast topic shared by all participants and the group chat manager.
        self._group_topic_type = f"group_topic_{self._team_id}"
        # Direct topic of the group chat manager.
        self._group_chat_manager_topic_type = f"{group_chat_manager_name}_{self._team_id}"
        # Direct topic of each participant, in participant order.
        self._participant_topic_types: List[str] = [f"{name}_{self._team_id}" for name in participant_names]
        # Topic for streamed output; the group chat manager relays it into the output queue.
        self._output_topic_type = f"output_topic_{self._team_id}"

        # Queue from which run_stream reads the relayed output messages.
        self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = (
            asyncio.Queue()
        )

        # Use the caller's runtime, or an embedded single-threaded one. The embedded runtime
        # must not ignore background exceptions, otherwise they go unsurfaced and the team
        # terminates early.
        self._embedded_runtime = runtime is None
        self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) if runtime is None else runtime

        # Lifecycle flags: registered with the runtime yet? currently running?
        self._initialized = False
        self._is_running = False
        # Whether team events should be emitted.
        self._emit_team_events = emit_team_events

    @abstractmethod
    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], SequentialRoutedAgent]:
        """Return a zero-argument factory that builds this team's group chat manager.

        Subclasses must implement this. ``_init`` calls it with keyword arguments and
        registers the returned factory with the runtime, using the manager's topic type
        as the agent type.

        Args:
            name: The group chat manager name.
            group_topic_type: Broadcast topic type shared by the participants and the manager.
            output_topic_type: Topic type whose messages the manager relays to the output queue.
            participant_topic_types: Direct topic type of each participant, in participant order.
            participant_names: Participant names, in the same order.
            participant_descriptions: Participant descriptions, in the same order.
            output_message_queue: Queue that ``run_stream`` reads output messages from.
            termination_condition: Optional termination condition for the chat.
            max_turns: Optional maximum number of turns.
            message_factory: The team's message factory, with custom and structured
                message types already registered.
        """
        ...

    def _create_participant_factory(
        self,
        parent_topic_type: str,
        output_topic_type: str,
        agent: ChatAgent,
        message_factory: MessageFactory,
    ) -> Callable[[], ChatAgentContainer]:
        """Return a zero-argument callable that wraps ``agent`` in a ChatAgentContainer.

        Each call of the returned callable builds a new container bound to the given
        parent (group) topic type, output topic type and message factory.
        """
        return lambda: ChatAgentContainer(parent_topic_type, output_topic_type, agent, message_factory)

    async def _init(self, runtime: AgentRuntime) -> None:
        """Register the participants and the group chat manager on ``runtime``.

        Each participant is registered under its own topic type. It is subscribed to that
        direct topic and to the shared group topic. The group chat manager is then
        registered under its topic type. It is subscribed to its direct topic, the group
        topic and the output topic (whose messages it relays to the output queue).
        Finally the team is marked as initialized.
        """
        manager_agent_type = AgentType(self._group_chat_manager_topic_type).type

        for participant, participant_type in zip(self._participants, self._participant_topic_types, strict=True):
            participant_factory = self._create_participant_factory(
                self._group_topic_type, self._output_topic_type, participant, self._message_factory
            )
            await ChatAgentContainer.register(runtime, type=participant_type, factory=participant_factory)
            # Direct topic first, then the group broadcast topic.
            for topic in (participant_type, self._group_topic_type):
                await runtime.add_subscription(TypeSubscription(topic_type=topic, agent_type=participant_type))

        manager_factory = self._create_group_chat_manager_factory(
            name=self._group_chat_manager_name,
            group_topic_type=self._group_topic_type,
            output_topic_type=self._output_topic_type,
            participant_names=self._participant_names,
            participant_topic_types=self._participant_topic_types,
            participant_descriptions=self._participant_descriptions,
            output_message_queue=self._output_message_queue,
            termination_condition=self._termination_condition,
            max_turns=self._max_turns,
            message_factory=self._message_factory,
        )
        await self._base_group_chat_manager_class.register(runtime, type=manager_agent_type, factory=manager_factory)
        # Direct topic, group broadcast topic, then the output topic.
        for topic in (self._group_chat_manager_topic_type, self._group_topic_type, self._output_topic_type):
            await runtime.add_subscription(TypeSubscription(topic_type=topic, agent_type=manager_agent_type))

        self._initialized = True

    async def run_stream(
        self,
        cancellation_token: CancellationToken | None = None,
    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
        """Run the group chat interactively and stream its output.

        The group chat manager is sent a ``GroupChatStart`` with no initial messages.
        While the chat runs, lines read from stdin are forwarded to the group chat
        manager as ``TextMessage(source="user")``. Entering ``exit`` ends the stream
        with the stop reason ``"User requested exit."``.

        Args:
            cancellation_token: Optional token passed along with the messages sent to
                the group chat manager and linked to every pending output-queue read.

        Yields:
            Each event/message taken from the output queue, then a final
            :class:`TaskResult` whose ``messages`` exclude
            :class:`ModelClientStreamingChunkEvent` items.

        Raises:
            ValueError: If the team is already running.
            RuntimeError: If the group chat terminates with an error.
        """
        # This interactive variant starts the chat without an initial task.
        messages: List[BaseChatMessage] | None = None
        if self._is_running:
            raise ValueError("The team is already running, it cannot run again until it is stopped.")
        self._is_running = True

        if self._embedded_runtime:
            # Start the embedded runtime.
            assert isinstance(self._runtime, SingleThreadedAgentRuntime)
            self._runtime.start()

        if not self._initialized:
            await self._init(self._runtime)

        shutdown_task: asyncio.Task[None] | None = None
        if self._embedded_runtime:

            async def stop_runtime() -> None:
                assert isinstance(self._runtime, SingleThreadedAgentRuntime)
                try:
                    # This will propagate any exceptions raised.
                    await self._runtime.stop_when_idle()
                    # Put a termination message in the queue to indicate that the group chat is stopped for whatever reason
                    # but not due to an exception.
                    await self._output_message_queue.put(
                        GroupChatTermination(
                            message=StopMessage(
                                content="The group chat is stopped.", source=self._group_chat_manager_name
                            )
                        )
                    )
                except Exception as e:
                    # Stop the consumption of messages and end the stream.
                    # NOTE: we also need to put a GroupChatTermination event here because when the runtime
                    # has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
                    await self._output_message_queue.put(
                        GroupChatTermination(
                            message=StopMessage(
                                content="An exception occurred in the runtime.", source=self._group_chat_manager_name
                            ),
                            error=SerializableException.from_exception(e),
                        )
                    )

            # Background task that stops the runtime and ends the stream.
            # NOTE(review): stop_when_idle presumably returns once the runtime has no pending
            # work, which may happen while waiting for the user to type; confirm this is
            # intended for interactive use.
            shutdown_task = asyncio.create_task(stop_runtime())

        async def user_input_listener() -> None:
            """Forward stdin lines to the group chat manager until 'exit', an error, or stop."""
            while self._is_running:
                try:
                    # input() blocks, so it runs in a worker thread.
                    user_input = await asyncio.to_thread(input, "请输入消息: ")
                    if user_input.strip().lower() == 'exit':
                        await self._output_message_queue.put(
                            GroupChatTermination(
                                message=StopMessage(
                                    content="User requested exit.", source="user"
                                )
                            )
                        )
                        break
                    # NOTE(review): assumes the group chat manager handles a raw TextMessage
                    # sent directly to it; confirm against the manager class.
                    user_message = TextMessage(content=user_input, source="user")
                    await self._runtime.send_message(
                        user_message,
                        recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
                        cancellation_token=cancellation_token,
                    )
                except Exception as e:
                    print(f"Error handling user input: {e}")
                    break

        # Tracked outside the ``try`` so the ``finally`` can always cancel them.
        input_task: asyncio.Task[None] | None = None
        message_future: asyncio.Future[BaseAgentEvent | BaseChatMessage | GroupChatTermination] | None = None
        try:
            # Run the team by sending the start message to the group chat manager.
            # The group chat manager will start the group chat by relaying the message to the participants
            # and the group chat manager.
            await self._runtime.send_message(
                GroupChatStart(messages=messages),
                recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
                cancellation_token=cancellation_token,
            )
            # Collect the output messages in order.
            output_messages: List[BaseAgentEvent | BaseChatMessage] = []
            stop_reason: str | None = None

            input_task = asyncio.create_task(user_input_listener())

            while True:
                if message_future is None:
                    # Start a new queue read only after the previous one was consumed, so no
                    # stray get() is left pending to swallow a later message.
                    message_future = asyncio.ensure_future(self._output_message_queue.get())
                    if cancellation_token is not None:
                        cancellation_token.link_future(message_future)

                # Wait on the input task only while it is still running: a finished task
                # would make asyncio.wait return immediately on every pass (busy loop).
                waitables: set[asyncio.Future[Any]] = {message_future}
                if not input_task.done():
                    waitables.add(input_task)
                done, _ = await asyncio.wait(waitables, return_when=asyncio.FIRST_COMPLETED)

                if input_task in done and not input_task.cancelled():
                    input_error = input_task.exception()
                    if input_error is not None:
                        print(f"User input task error: {input_error}")

                if message_future not in done:
                    continue

                message = message_future.result()
                message_future = None
                if isinstance(message, GroupChatTermination):
                    if message.error is not None:
                        raise RuntimeError(str(message.error))
                    stop_reason = message.message.content
                    break
                yield message
                if isinstance(message, ModelClientStreamingChunkEvent):
                    # Streaming chunks are yielded but not kept in the final result.
                    continue
                output_messages.append(message)

            # Yield the final result.
            yield TaskResult(messages=output_messages, stop_reason=stop_reason)

        finally:
            # Stop the helpers on every exit path: normal end, error, cancellation,
            # or the consumer closing the generator early.
            if input_task is not None and not input_task.done():
                # This cancels only the awaiting coroutine; a worker thread already
                # blocked in input() keeps waiting until a line is entered.
                input_task.cancel()
            if message_future is not None and not message_future.done():
                message_future.cancel()
            try:
                if shutdown_task is not None:
                    # Wait for the shutdown task to finish.
                    # This will propagate any exceptions raised.
                    await shutdown_task
            finally:
                # Clear the output message queue.
                while not self._output_message_queue.empty():
                    self._output_message_queue.get_nowait()

                # Indicate that the team is no longer running.
                self._is_running = False

    async def reset(self) -> None:
        """Reset the team and its participants to their initial state.

        The team is initialized first if it has not been yet, and it must be stopped
        before it can be reset. A ``GroupChatReset`` is sent to every participant and
        then to the group chat manager. An embedded runtime is started for the reset
        and stopped (when idle) afterwards.

        Raises:
            RuntimeError: If the team is currently running.
        """
        if not self._initialized:
            await self._init(self._runtime)

        if self._is_running:
            raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.")

        if self._embedded_runtime:
            # Start the runtime before claiming the running flag, so a failed start
            # cannot leave the team marked as running.
            assert isinstance(self._runtime, SingleThreadedAgentRuntime)
            self._runtime.start()
        self._is_running = True

        try:
            # Send a reset message to all participants.
            for participant_topic_type in self._participant_topic_types:
                await self._runtime.send_message(
                    GroupChatReset(),
                    recipient=AgentId(type=participant_topic_type, key=self._team_id),
                )
            # Send a reset message to the group chat manager.
            await self._runtime.send_message(
                GroupChatReset(),
                recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
            )
        finally:
            try:
                if self._embedded_runtime:
                    # Stop the runtime; this may raise background exceptions.
                    assert isinstance(self._runtime, SingleThreadedAgentRuntime)
                    await self._runtime.stop_when_idle()
            finally:
                # Always runs, even if stopping the runtime raised, so the team never
                # stays stuck as "running" with stale output left in the queue.
                while not self._output_message_queue.empty():
                    self._output_message_queue.get_nowait()

                # Indicate that the team is no longer running.
                self._is_running = False

    async def pause(self) -> None:
        """Pause the participants and the group chat manager while the team is running.

        A ``GroupChatPause`` message is sent directly (via ``send_message``) to each
        participant in order, and finally to the group chat manager.

        Raises:
            RuntimeError: If the team has not been initialized yet.
        """
        if not self._initialized:
            raise RuntimeError("The group chat has not been initialized. It must be run before it can be paused.")

        recipients = [AgentId(type=topic, key=self._team_id) for topic in self._participant_topic_types]
        # The group chat manager is paused last.
        recipients.append(AgentId(type=self._group_chat_manager_topic_type, key=self._team_id))
        for recipient in recipients:
            await self._runtime.send_message(GroupChatPause(), recipient=recipient)

    async def resume(self) -> None:
        """Resume the participants and the group chat manager after a pause.

        A ``GroupChatResume`` message is sent directly (via ``send_message``) to each
        participant in order, and finally to the group chat manager.

        Raises:
            RuntimeError: If the team has not been initialized yet.
        """
        if not self._initialized:
            raise RuntimeError("The group chat has not been initialized. It must be run before it can be resumed.")

        recipients = [AgentId(type=topic, key=self._team_id) for topic in self._participant_topic_types]
        # The group chat manager is resumed last.
        recipients.append(AgentId(type=self._group_chat_manager_topic_type, key=self._team_id))
        for recipient in recipients:
            await self._runtime.send_message(GroupChatResume(), recipient=recipient)

    async def save_state(self) -> Mapping[str, Any]:
        """Save the state of the group chat team.

        The team is initialized first if needed. The state of each participant and of
        the group chat manager is fetched through the runtime and keyed by agent name.

        Returns:
            The ``TeamState`` holding those per-agent states, dumped to a mapping.
        """
        if not self._initialized:
            await self._init(self._runtime)

        # Keys are agent names, not runtime agent IDs, so the saved state is decoupled
        # from the agents' identities in the runtime. The runtime's agent_save_state is
        # used (not the agents' own save_state) so remote agents are supported too.
        named_agent_types = [
            *zip(self._participant_names, self._participant_topic_types, strict=True),
            (self._group_chat_manager_name, self._group_chat_manager_topic_type),
        ]
        agent_states: Dict[str, Mapping[str, Any]] = {}
        for agent_name, agent_type in named_agent_types:
            agent_states[agent_name] = await self._runtime.agent_save_state(
                AgentId(type=agent_type, key=self._team_id)
            )
        return TeamState(agent_states=agent_states).model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Load an external state and overwrite the current state of the group chat team.

        The team is initialized first if needed. ``state`` is validated as a
        ``TeamState``. Every participant and the group chat manager must have an
        entry (keyed by name), and this is checked before any state is loaded.

        Args:
            state: A mapping such as the one returned by :meth:`save_state`.

        Raises:
            RuntimeError: If the team is currently running.
            ValueError: If ``state`` fails validation or lacks the state of a
                participant or of the group chat manager.
        """
        if not self._initialized:
            await self._init(self._runtime)

        if self._is_running:
            raise RuntimeError("The team cannot be loaded while it is running.")
        self._is_running = True

        try:
            team_state = TeamState.model_validate(state)
            agent_states = team_state.agent_states
            # Pair each agent name with its runtime agent ID: participants first, then
            # the group chat manager.
            targets = [
                *(
                    (name, AgentId(type=agent_type, key=self._team_id))
                    for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True)
                ),
                (self._group_chat_manager_name, AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)),
            ]
            # Check for missing entries up front so an incomplete state cannot leave the
            # team partially overwritten.
            for name, _ in targets:
                if name not in agent_states:
                    raise ValueError(f"Agent state for {name} not found in the saved state.")
            for name, agent_id in targets:
                await self._runtime.agent_load_state(agent_id, agent_states[name])

        except ValidationError as e:
            raise ValueError(
                "Invalid state format. The expected state format has changed since v0.4.9. "
                "Please read the release note on GitHub."
            ) from e

        finally:
            # Indicate that the team is no longer running.
            self._is_running = False